%% Calculate X_hat 

function [P_hat_n,P_hat_nj,X_hat_mnj,pi_c_nj1,PC_hat_n,...
    pseudo_intermediates, pseudo_final_goods] = ...
    ExpenditureCESOlig_fun(P_hat_mnj,pi_M1,pi_l1,w_hat,...
    pi_M_f1,pi_l_f1,pi_c_mnj0,PC_n,pi_c_nj,...
    X_mnj,pi_c_mnj1,zeta_mnj,pi_nkjf1,pi_nki1,X_hat_nki)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Author: Produced for JdG, AL, IM by Christopher Evans at UPF
%
% Purpose: Computes the expenditure change X_hat (equation (23) in the
% master notes) by fixed-point iteration, together with the sector price
% changes P_hat_nj and the country price changes P_hat_n. The wedge change
% zeta_hat is set to zero, which yields the no-wedge X_hat that is then
% used to start the counterfactual analysis.
%
% Array shapes (inferred from how the arrays are combined below; confirm
% against the callers):
%   X_mnj, X_hat_nki, P_hat_mnj, pi_c_mnj0, pi_c_mnj1, zeta_mnj, pi_nki1
%                : [N*J x N]  rows = (origin country, sector) blocks
%                              delimited by start, columns = country
%   pi_M1        : [N*J x N*J] intermediate-input shares
%   pi_l1        : [N*J x 1]   labour shares (1-pi_l1 scales column n)
%   pi_c_nj      : [J x N]     final-consumption shares (held fixed)
%   w_hat, PC_n  : [1 x N] (w_hat may presumably also be a scalar)
%   pi_nkjf1     : [F x N]     French-firm shares (F = rows of
%                              firmsecD_sorted)
%   pi_l_f1      : [F x 1]     French-firm labour shares
%   pi_M_f1      : [F x N*J]   French-firm intermediate shares
%
% Outputs:
%   P_hat_nj [J x N], P_hat_n [1 x N], X_hat_mnj [N*J x N],
%   pi_c_nj1 (equal to pi_c_nj), PC_hat_n [1 x N],
%   pseudo_intermediates [N*J x N], pseudo_final_goods [N*J x N].
%   If the iteration does not reach tol within maxit steps a warning is
%   issued; if no iteration is performed at all, X_hat_mnj is the incoming
%   guess X_hat_nki and pseudo_intermediates is all zeros.
%
% History: all subscripts follow master_notes3.pdf and the draft, i.e.
% {mn,ij} pairs for country pair {mn} and sector pair {ij}. The global
% variables and file structure for loading data were changed accordingly,
% and the gamma_type variable was removed. (The original note also
% mentioned adjusting the *_2.csv input files for the new directory; the
% exact wording was garbled - confirm.) Compared to the Cobb-Douglas
% version, updated next-period labour and intermediate shares computed in
% the TradeShare function are now passed in.
%
% Last Updated: 30/01/2019 by JDG
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

global alphaT rhoT lambdaT etaT rho sigma lambda eta varphi ...
    tol maxit Pi_hat_n D_hat_n Xi_hat_mnjf Xi_hat_mnj a_hat_f a_hat ...
    N J FI_J france countrysecD firmsecD_sorted start start_sorted ...
    sum_by_country_dummy ISO ... 
    sL_n sPi_n sD_n alpha_WIOD_j alpha_WIOD_francej labour_share_PWT

%% Sector price index change P_hat_nj (CES aggregation over origins m)

% Tile the sector elasticities sigma to the [N*J x N] layout of P_hat_mnj.
sigma_mnj_repmat = repmat(sigma, [N N]);

P_hat_nj = (sum_by_country_dummy*(P_hat_mnj.^(1-sigma_mnj_repmat).*pi_c_mnj0)) ...
    .^(1./(1-repmat(sigma,[1 N])));

% 0^(negative exponent) = Inf appears where country n buys no final goods
% from sector j; those entries are set to 1 (no price change).
P_hat_nj(P_hat_nj==inf) = 1;

%% Country price index change P_hat_n
% Cobb-Douglas aggregation over sectors with weights pi_c_nj, computed in
% logs: P_hat_n = prod_j P_hat_nj.^pi_c_nj.
P_hat_n = exp(sum(pi_c_nj.*log(P_hat_nj), 1));

% Final-consumption shares are held at their previous-period values.
pi_c_nj1 = pi_c_nj;

%% Loop-invariant pieces of the expenditure update
% None of these depend on the X_hat guess, so they are computed once.

% Change in the wedge, set to zero to obtain the no-wedge X_hat.
zeta_hat_mnj = zeros(N*J, N);

% (rho-1)/rho factor, one value per purchasing (column) country-sector.
rho_mnjj_repmat = repmat(repmat(rho,[N 1])', [N*J 1]);
rho_factor_mnjj = (rho_mnjj_repmat-1)./rho_mnjj_repmat;

% Non-labour share (1-pi_l1) of the purchasing country-sector, laid out
% so that column c carries 1-pi_l1(c).
one_minus_pil1_cols = (repmat((1-pi_l1),[1 N*J]))';

% Square bracket of equation (14): change in final consumption
% expenditure; equals 1 in the absence of shocks.
PC_hat_n = w_hat.*(w_hat./P_hat_n).^(1/(varphi-1)).*sL_n + ...
    Pi_hat_n.*sPi_n + D_hat_n.*sD_n;

% Pseudo final-goods expenditure, [N*J x N].
pseudo_final_goods = pi_c_mnj1.*repmat(pi_c_nj1,[N 1]).* ...
    repmat(PC_hat_n,[N*J 1]).*repmat(PC_n,[N*J 1]);

% French firm markups mu_mnjf1 [F x N], from firm-level sigma and rho.
repmat_sigma = repmat(firmsecD_sorted * sigma, [1,N]);
repmat_rho   = repmat(firmsecD_sorted * rho,   [1,N]);
mu_mnjf1 = repmat_sigma.*repmat_rho./ ...
    (repmat_sigma.*(repmat_rho-1)-(repmat_rho-repmat_sigma).*pi_nkjf1);

% Rows of X_mnj belonging to France.
france_rows = start(france):start(france+1)-1;

%% Fixed-point iteration on X_hat

% Defaults returned if the loop body never runs.
X_hat_mnj = X_hat_nki;
pseudo_intermediates = zeros(N*J, N);

Xmax = 1;
it   = 1;

while (it <= maxit) && (Xmax > tol)
    %% All countries

    % Sales of each origin country-sector, summed over destinations.
    X_pi = sum(pi_nki1.*X_hat_nki.*X_mnj, 2);

    % Scale column c of pi_M1 by X_pi(c), i.e. pi_M1*diag(X_pi). This
    % assumes the start blocks tile 1:N*J, which the [N*J x N*J] .* below
    % already requires.
    pi_M1_X_pi = bsxfun(@times, pi_M1, X_pi.');

    % Pseudo Z matrix at the (mi, nj) level.
    pseudo_intermediates_mnjj = rho_factor_mnjj.*(pi_M1_X_pi.*one_minus_pil1_cols);

    % Sum over the purchasing sectors j of each country n -> [N*J x N].
    for n = 1:N
        pseudo_intermediates(:,n) = ...
            sum(pseudo_intermediates_mnjj(:,start(n):start(n+1)-1), 2);
    end

    %% France (French firms)

    % Firm-level expenditure [F x N] from France's sector rows.
    X_mnjf = firmsecD_sorted*(X_mnj(france_rows,:).*X_hat_nki(france_rows,:));

    % Markup-adjusted sales summed over destinations, times the
    % non-labour share of each firm -> [F x 1].
    pi_X_laborshare_njf = sum(pi_nkjf1.*X_mnjf./mu_mnjf1, 2).*(1-pi_l_f1);

    % Distribute over supplying country-sectors and sum over firms, then
    % stack as the France column of pseudo_intermediates.
    french_Z_mnj = sum(repmat(pi_X_laborshare_njf, [1 J*N]).*pi_M_f1, 1);
    pseudo_intermediates(:,france) = reshape(french_Z_mnj, [N*J 1]);

    %% New X_hat from equation (23) (zeta_hat is zero)

    X_hat_mnj_X_mnj = pseudo_final_goods + pseudo_intermediates + ...
        zeta_hat_mnj.*zeta_mnj;

    % Replace exact zeros by 1e-15 to avoid 0/0 in the ratio below. X_mnj
    % is padded here (after it was used above), so the first pass uses
    % the unpadded values, as in the original implementation.
    X_hat_mnj_X_mnj = X_hat_mnj_X_mnj + 0.000000000000001.*(X_hat_mnj_X_mnj==0);
    X_mnj = X_mnj + 0.000000000000001.*(X_mnj==0);

    X_hat_mnj = X_hat_mnj_X_mnj./X_mnj;

    % Relative change between the guess and the update, as a vector.
    X_hat_nki_vec = reshape(X_hat_nki, [N*N*J 1]);
    X_hat_mnj_vec = reshape(X_hat_mnj, [N*N*J 1]);
    X_hat_nki_vec = X_hat_nki_vec + .00000000000001.*(X_hat_nki_vec==0);
    X_diff = abs(X_hat_nki_vec - X_hat_mnj_vec)./X_hat_nki_vec;

    % The update becomes the next guess; the distance is the 2-norm of
    % the relative changes.
    X_hat_nki = X_hat_mnj;
    Xmax = norm(X_diff);

    it = it + 1;

    fprintf('   EXPENDITURE ITERATION=%d        Max expenditure distance=%e  \n',it-1,Xmax);
end

% Do not let a non-converged result pass silently.
if Xmax > tol
    warning('ExpenditureCESOlig_fun:noConvergence', ...
        'Expenditure iteration stopped after %d iterations with distance %e > tol %e.', ...
        it-1, Xmax, tol);
end

end